{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "992c4695-ec4f-428d-bd05-fb3b5fbd70f4",
      "metadata": {},
      "source": [
        "# How to add memory to the prebuilt ReAct agent\n",
        "\n",
        "This tutorial will show how to add memory to the prebuilt ReAct agent. Please see [this tutorial](./create-react-agent.ipynb) for how to get started with the prebuilt ReAct agent\n",
        "\n",
        "All we need to do to enable memory is pass in a checkpointer to `create_react_agents`"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "7be3889f-3c17-4fa1-bd2b-84114a2c7247",
      "metadata": {},
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "id": "a213e11a-5c62-4ddb-a707-490d91add383",
      "metadata": {},
      "outputs": [],
      "source": [
        "%%capture --no-stderr\n",
        "%pip install -U langgraph langchain-openai"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "id": "23a1885c-04ab-4750-aefa-105891fddf3e",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "OPENAI_API_KEY:  ········\n"
          ]
        }
      ],
      "source": [
        "import getpass\n",
        "import os\n",
        "\n",
        "\n",
        "def _set_env(var: str):\n",
        "    if not os.environ.get(var):\n",
        "        os.environ[var] = getpass.getpass(f\"{var}: \")\n",
        "\n",
        "\n",
        "_set_env(\"OPENAI_API_KEY\")\n",
        "\n",
        "# Recommended\n",
        "_set_env(\"LANGCHAIN_API_KEY\")\n",
        "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
        "os.environ[\"LANGCHAIN_PROJECT\"] = \"Create ReAct Agent Tutorial\""
      ]
    },
    {
      "cell_type": "markdown",
      "id": "03c0f089-070c-4cd4-87e0-6c51f2477b82",
      "metadata": {},
      "source": [
        "## Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "id": "7a154152-973e-4b5d-aa13-48c617744a4c",
      "metadata": {},
      "outputs": [],
      "source": [
        "# First we initialize the model we want to use.\n",
        "from langchain_openai import ChatOpenAI\n",
        "\n",
        "model = ChatOpenAI(model=\"gpt-4o\", temperature=0)\n",
        "\n",
        "\n",
        "# For this tutorial we will use custom tool that returns pre-defined values for weather in two cities (NYC & SF)\n",
        "\n",
        "from typing import Literal\n",
        "\n",
        "from langchain_core.tools import tool\n",
        "\n",
        "\n",
        "@tool\n",
        "def get_weather(city: Literal[\"nyc\", \"sf\"]):\n",
        "    \"\"\"Use this to get weather information.\"\"\"\n",
        "    if city == \"nyc\":\n",
        "        return \"It might be cloudy in nyc\"\n",
        "    elif city == \"sf\":\n",
        "        return \"It's always sunny in sf\"\n",
        "    else:\n",
        "        raise AssertionError(\"Unknown city\")\n",
        "\n",
        "\n",
        "tools = [get_weather]\n",
        "\n",
        "# We can add \"chat memory\" to the graph with LangGraph's checkpointer\n",
        "# to retain the chat context between interactions\n",
        "from langgraph.checkpoint.memory import MemorySaver\n",
        "\n",
        "memory = MemorySaver()\n",
        "\n",
        "# Define the graph\n",
        "\n",
        "from langgraph.prebuilt import create_react_agent\n",
        "\n",
        "graph = create_react_agent(model, tools=tools, checkpointer=memory)"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "00407425-506d-4ffd-9c86-987921d8c844",
      "metadata": {},
      "source": [
        "## Usage\n",
        "\n",
        "Let's interact with it multiple times to show that it can remember"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "id": "16636975-5f2d-4dc7-ab8e-d0bea0830a28",
      "metadata": {},
      "outputs": [],
      "source": [
        "def print_stream(stream):\n",
        "    for s in stream:\n",
        "        message = s[\"messages\"][-1]\n",
        "        if isinstance(message, tuple):\n",
        "            print(message)\n",
        "        else:\n",
        "            message.pretty_print()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "id": "9ffff6c3-a4f5-47c9-b51d-97caaee85cd6",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "================================\u001b[1m Human Message \u001b[0m=================================\n",
            "\n",
            "What's the weather in NYC?\n",
            "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
            "Tool Calls:\n",
            "  get_weather (call_mdovy4yXSSYrmSlnlVSUacVn)\n",
            " Call ID: call_mdovy4yXSSYrmSlnlVSUacVn\n",
            "  Args:\n",
            "    city: nyc\n",
            "=================================\u001b[1m Tool Message \u001b[0m=================================\n",
            "Name: get_weather\n",
            "\n",
            "It might be cloudy in nyc\n",
            "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
            "\n",
            "The weather in NYC might be cloudy.\n"
          ]
        }
      ],
      "source": [
        "config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
        "inputs = {\"messages\": [(\"user\", \"What's the weather in NYC?\")]}\n",
        "\n",
        "print_stream(graph.stream(inputs, config=config, stream_mode=\"values\"))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "838a043f-90ad-4e69-9d1d-6e22db2c346c",
      "metadata": {},
      "source": [
        "Notice that when we pass the same the same thread ID, the chat history is preserved"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "187479f9-32fa-4611-9487-cf816ba2e147",
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "================================\u001b[1m Human Message \u001b[0m=================================\n",
            "\n",
            "What's it known for?\n",
            "==================================\u001b[1m Ai Message \u001b[0m==================================\n",
            "\n",
            "New York City (NYC) is known for many things, including:\n",
            "\n",
            "1. **Landmarks and Attractions**: The Statue of Liberty, Times Square, Central Park, Empire State Building, and Brooklyn Bridge.\n",
            "2. **Cultural Institutions**: Broadway theaters, Metropolitan Museum of Art, Museum of Modern Art (MoMA), and the American Museum of Natural History.\n",
            "3. **Diverse Neighborhoods**: Areas like Chinatown, Little Italy, Harlem, and Greenwich Village.\n",
            "4. **Financial Hub**: Wall Street and the New York Stock Exchange.\n",
            "5. **Cuisine**: A melting pot of global cuisines, famous for its pizza, bagels, and street food.\n",
            "6. **Media and Entertainment**: Home to major media companies, TV networks, and film studios.\n",
            "7. **Fashion**: A global fashion capital, hosting New York Fashion Week.\n",
            "8. **Sports**: Teams like the New York Yankees, New York Mets, New York Knicks, and New York Rangers.\n",
            "9. **Public Transportation**: An extensive subway system and iconic yellow taxis.\n",
            "10. **Events**: New Year's Eve celebration in Times Square, Macy's Thanksgiving Day Parade, and various cultural festivals.\n"
          ]
        }
      ],
      "source": [
        "inputs = {\"messages\": [(\"user\", \"What's it known for?\")]}\n",
        "print_stream(graph.stream(inputs, config=config, stream_mode=\"values\"))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.1"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
